import tensorflow as tf
tf = tf.compat.v1
def stats_graph(graph, log_path='log.txt'):
    """Profile ``graph`` and log its GFLOPs and trainable-parameter count.

    Runs ``tf.profiler.profile`` twice on ``graph`` -- once with the
    float-operation options and once with the trainable-variables-parameter
    options -- and writes a single summary line to ``log_path``.

    Args:
        graph: Graph object handed straight to ``tf.profiler.profile``.
        log_path: File that is overwritten with the summary line. Defaults
            to ``'log.txt'``.
    """
    option_builder = tf.profiler.ProfileOptionBuilder
    flops = tf.profiler.profile(
        graph, options=option_builder.float_operation()
    )
    params = tf.profiler.profile(
        graph, options=option_builder.trainable_variables_parameter()
    )

    # Build the line before touching the file, so a failure while reading the
    # profiler results does not truncate an existing log.
    summary = 'GFLOPs: {}; Trainable params: {}'.format(
        flops.total_float_ops / 1000000000.0, params.total_parameters
    )

    # The context manager closes the handle even if the write raises.
    with open(log_path, 'w') as log_file:
        print(summary, file=log_file)
# Export directory that is handed to the SavedModel loader below.
input_saved_model_dir = "./1719407478/"

# Open a session on a brand-new graph, load the "serve"-tagged model into it,
# then profile whatever graph is the default inside the session block.
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess=sess,
        tags=["serve"],
        export_dir=input_saved_model_dir,
    )
    graph = tf.get_default_graph()
    stats_graph(graph)